# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change or deletion.
"""
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from .op_parallel_config import default_dpmp_config

# Public API of this module: only MoEConfig is exported through `from ... import *`.
__all__ = [
    "MoEConfig"]


class MoEConfig:
    r"""
        Configuration holder for MoE (Mixture of Expert).

        Args:
            expert_num (int): How many experts are employed. Default: 1
            capacity_factor (float): Expansion factor applied to the expert capacity; it must be >= 1.0.
                Default: 1.1.
            aux_loss_factor (float): Weight of the load balance loss (produced by the router) when it is added
                to the entire model loss; it must be < 1.0. Default: 0.05.
            num_experts_chosen (int): How many experts each token selects. Only the 'Top1' routing policy is
                supported currently, so the value must be 1. Default: 1.

        Raises:
            ValueError: If `capacity_factor` is less than 1.0, if `aux_loss_factor` is not less than 1.0, or if
                `num_experts_chosen` is not 1.

        Supported Platforms:
            ``Ascend`` ``GPU``

        Examples:
            >>> from mindspore.nn.transformer import MoEConfig
            >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1)
    """
    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
                 num_experts_chosen=1):
        # Generic Validator checks run first, one per argument and in argument order.
        basic_checks = (
            (Validator.check_positive_int, expert_num, "expert_num"),
            (Validator.check_positive_float, capacity_factor, "capacity_factor"),
            (Validator.check_positive_float, aux_loss_factor, "aux_loss_factor"),
            (Validator.check_positive_int, num_experts_chosen, "num_experts_chosen"),
        )
        for check, value, label in basic_checks:
            check(value, label)

        # MoE-specific range constraints.
        if capacity_factor < 1.0:
            message = (f"'capacity_factor' should be equal to or greater than 1.0, "
                       f"but got {capacity_factor}.")
            raise ValueError(message)
        if aux_loss_factor >= 1.0:
            message = (f"'aux_loss_factor' should be less than 1.0, "
                       f"but got {aux_loss_factor}.")
            raise ValueError(message)
        if num_experts_chosen != 1:
            message = ("'num_experts_chosen' should be 1. Since only 'Top1' routing policy supported currently, "
                       "the value should be 1.")
            raise ValueError(message)

        # Store the validated values unchanged.
        self.expert_num, self.capacity_factor = expert_num, capacity_factor
        self.aux_loss_factor, self.num_experts_chosen = aux_loss_factor, num_experts_chosen


# Module-level MoEConfig built with all default arguments; used as the default `moe_config` of MoE.
default_moe_config = MoEConfig()


def _check_moe_config(moe_config=None, parallel_config=None):
    """
    Validate `moe_config` and, when MoE is in use, its compatibility with `parallel_config`.

    Args:
        moe_config (MoEConfig): The MoE configuration to check.
        parallel_config: Object exposing a `data_parallel` attribute. It is only read when
            `moe_config.expert_num` is greater than 1.

    Raises:
        TypeError: If `moe_config` is not an instance of MoEConfig.
        ValueError: If `expert_num` is greater than 1 and is not a multiple of `data_parallel`.
    """
    if not isinstance(moe_config, MoEConfig):
        raise TypeError(f"'moe_config' should be an instance of MoEConfig, but got {type(moe_config).__name__}.")

    expert_num = moe_config.expert_num
    # MoE counts as enabled only with more than one expert; otherwise nothing further is checked.
    if not expert_num > 1:
        return

    data_parallel = parallel_config.data_parallel
    if expert_num % data_parallel != 0:
        raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
                         f"of 'data_parallel' value in {type(parallel_config).__name__}, but got "
                         f"{expert_num} for 'expert_num' and {data_parallel} for "
                         f"'data_parallel'.")


@constexpr
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
    """
    Return ceil(k * tokens_per_device * capacity_factor / expert_dim).

    Args:
        k: Multiplier applied to the token count (callers in this file pass the number of experts chosen per token).
        tokens_per_device: Number of tokens handled by each device.
        capacity_factor: Expansion factor applied to the capacity.
        expert_dim: Number of experts the tokens are spread over.

    Returns:
        int, the result of `math.ceil` on the scaled token count divided by `expert_dim`.
    """
    scaled_tokens = k * tokens_per_device * capacity_factor
    return math.ceil(scaled_tokens / expert_dim)


class MoE(Cell):
    """
    The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
    The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
    obtained by multiplying FeedForward's output and router's combine weight.

    Args:
        hidden_size (int): The dimension of the inputs.
        ffn_hidden_size (int): The intermediate hidden size.
        dropout_rate (float): The dropout rate for the second linear's output.
        hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
                         'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
                         'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
            Default: dtype.float32.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
            default values. Please see `MoEConfig`.
        parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
                                           Default `default_dpmp_config`, an instance of `OpParallelConfig` with default
                                           args.

    Inputs:
        - **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.

    Outputs:
        Tuple of 2 Tensors.

        - **output** (Tensor) - the output of this layer after mapping, reshaped back to the shape of the input,
          i.e. `[batch, seq_length, hidden_size]`.
        - **aux_loss** (Tensor) - the loss returned by the router multiplied by `aux_loss_factor` of `moe_config`.
    """
    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
                 hidden_act='gelu',
                 param_init_type=mstype.float32,
                 moe_config=default_moe_config,
                 parallel_config=default_dpmp_config):
        super(MoE, self).__init__()
        self.hidden_size = hidden_size
        # Values copied from the MoE configuration.
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.aux_loss_factor = moe_config.aux_loss_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        # Both attributes hold the data-parallel degree; `expert_parallel` is used for shapes in `construct`,
        # `dp` for the shard strategies below.
        self.expert_parallel = parallel_config.data_parallel
        self.dp = parallel_config.data_parallel
        # NOTE(review): imported locally rather than at module level, presumably to avoid a circular import with
        # the transformer module -- confirm.
        from .transformer import FeedForward

        # The expert network: a FeedForward built with `expert_num` experts.
        self.ffn = FeedForward(hidden_size=hidden_size,
                               ffn_hidden_size=ffn_hidden_size,
                               dropout_rate=dropout_rate,
                               hidden_act=hidden_act,
                               expert_num=self.expert_dim,
                               param_init_type=param_init_type,
                               parallel_config=parallel_config)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        # Transposes and batched matmuls are sharded along their first dimension by the data-parallel degree.
        self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
        self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
        self.transpose_4dim = P.Transpose().shard(((self.dp, 1, 1, 1),))
        self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        # Used to scale the auxiliary loss by `aux_loss_factor`.
        self.mul = P.Mul().shard(((), ()))
        # `routing_policy=None` makes Router create its default SwitchRouter.
        self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                             training=True, parallel_config=parallel_config)
        self.cast = P.Cast()


    def construct(self, input_tensor):
        """
        Dispatch tokens to the experts, run the expert FeedForward and combine the expert outputs.
        """
        # Remember the incoming shape so the output can be restored to it at the end.
        input_shape = F.shape(input_tensor)
        # Flatten to (total_tokens, hidden_size), then split the tokens into `expert_parallel` groups.
        input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
        bs_and_dmodel = self.shape(input_tensor)
        tokens_per_device = bs_and_dmodel[0] // self.expert_parallel
        input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device, self.hidden_size))

        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device,
                                                    self.capacity_factor, self.expert_dim)
        # dispatch_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)

        # after transpose, input_tensor's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
        input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1))
        dispatch_tensor = self.reshape(dispatch_tensor, (self.expert_parallel, tokens_per_device,
                                                         self.expert_dim * expert_capacity))
        # Cast the dispatch mask to the input's dtype so it can be used in the batched matmul.
        dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
        # expert_input's shape: (self.expert_parallel, self.hidden_size, self.expert_dim * expert_capacity)
        expert_input = self.batch_mm(input_tensor, dispatch_tensor)
        expert_input = self.reshape(expert_input, (self.expert_parallel, self.hidden_size, self.expert_dim,
                                                   expert_capacity))
        # The following four ops are to implement transpose(expert_input, (2, 0, 3, 1)), for that a single transpose
        # has bad performance
        expert_input = self.reshape(expert_input, (self.expert_parallel*self.hidden_size,
                                                   self.expert_dim*expert_capacity))
        expert_input = self.transpose_2dim(expert_input, (1, 0))
        expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.expert_parallel,
                                                   self.hidden_size))
        # expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
        expert_input = self.transpose_4dim(expert_input, (0, 2, 1, 3))
        # Flatten to 2D, (self.expert_dim * self.expert_parallel * expert_capacity, self.hidden_size), for the FFN.
        expert_input = self.reshape(expert_input, (self.expert_dim * self.expert_parallel * expert_capacity,
                                                   self.hidden_size))

        # expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
        expert_output = self.ffn(expert_input)
        expert_output = self.reshape(expert_output, (self.expert_dim, self.expert_parallel,
                                                     expert_capacity, self.hidden_size))
        # The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose
        # has bad performance
        expert_output = self.reshape(expert_output, (self.expert_dim,
                                                     self.expert_parallel*expert_capacity*self.hidden_size))
        expert_output = self.transpose_2dim(expert_output, (1, 0))
        expert_output = self.reshape(expert_output, (self.expert_parallel, expert_capacity,
                                                     self.hidden_size*self.expert_dim))
        expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
        # expert_output's shape: (self.expert_parallel, self.hidden_size, self.expert_dim, expert_capacity)
        expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size, self.expert_dim,
                                                     expert_capacity))
        # Merge the expert and capacity dimensions so expert_output can be multiplied with combine_tensor.
        expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size,
                                                     self.expert_dim*expert_capacity))
        combine_tensor = self.reshape(combine_tensor, (self.expert_parallel, tokens_per_device,
                                                       self.expert_dim*expert_capacity))
        # combine_tensor's shape: (self.expert_parallel, self.expert_dim*expert_capacity, tokens_per_device)
        combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1))
        combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))

        # combined_output's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
        combined_output = self.batch_mm2(expert_output, combine_tensor)
        # combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
        combined_output = self.transpose_3dim(combined_output, (0, 2, 1))
        # Back to the flattened (total_tokens, hidden_size) layout, then to the caller's original shape.
        combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
        combined_output = self.reshape(combined_output, input_shape)

        # Scale the router's loss by the configured factor before returning it.
        aux_loss = self.mul(self.aux_loss_factor, aux_loss)
        return combined_output, aux_loss


class Router(Cell):
    r"""
        A router backbone used to calculate logits of each token, which should be cascaded by router implementations
        mapping tokens to experts.

        Args:
            d_model (int): The hidden size of each token.
            moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
            routing_policy: The policy of mapping tokens to experts. If None, a SwitchRouter is created and used.
                Default: None.
            training (bool): The value indicating whether is in training phase. Default: True.
            parallel_config: The parallel-related configuration. If None, `default_dpmp_config` is used.
                Default: None.
        Inputs:
            - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
              hidden\_size)`.

        Outputs:
            The outputs of the routing policy applied to the router logits. With the default SwitchRouter these are
            the dispatch tensor, the combine tensor and the load balance loss.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 routing_policy=None,
                 training=True,
                 parallel_config=None):
        super(Router, self).__init__()
        # The signature advertises None as the default, so resolve it to the module's default parallel
        # configuration instead of failing on `None.data_parallel`.
        if parallel_config is None:
            parallel_config = default_dpmp_config
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.routing_policy = routing_policy
        self.noisy_policy = None  # candidate: ["jitter", "rsample", "None"]
        self.noisy_epsilon = 1e-2
        # Fixed multiplicative noise drawn once at construction; only used when `noisy_policy` is "jitter".
        # NOTE(review): np.random.uniform yields float64 while the input is cast to float32 in `construct` --
        # confirm dtype compatibility before enabling the jitter policy.
        self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))

        # Bias-free projection from the token hidden size to one logit per expert.
        self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
        self.dense.matmul.shard(((dp, 1), (1, 1)))
        self.mul = P.Mul().shard(((dp, 1, 1), (dp,)))
        self.cast = P.Cast()

        # Fall back to the Top1 SwitchRouter when no routing policy is supplied.
        if self.routing_policy is None:
            self.router = SwitchRouter(d_model=d_model, moe_config=moe_config, training=training,
                                       parallel_config=parallel_config)
        else:
            self.router = routing_policy

    def construct(self, input_tensor):
        # The logits are always computed from a float32 copy of the input.
        input_tensor = self.cast(input_tensor, mstype.float32)
        if self.noisy_policy == "jitter" and self.training:
            # Here, we temporarily implement the multiplicative jitter this way,
            # for the lack of UniformReal parallel operator.
            input_tensor = self.mul(input_tensor, self.noise)

        router_logits = self.dense(input_tensor)
        return self.router(router_logits)


class SwitchRouter(Cell):
    r"""
        A router implementation which maps each tokens to the top1 expert.
        Reference: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py

        Args:
            d_model (int): The hidden size of each token.
            moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
            training (bool): The value indicating whether is in training phase. Default: True.
            parallel_config: The parallel-related configuration. If None, `default_dpmp_config` is used.
                Default: None.
        Inputs:
            - **router_logits** (Tensor) - The router logits, whose last dimension is :math:`expert\_dim`. It is
              reshaped internally to :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.

        Outputs:
            Tuple of 3 Tensors.

            - **dispatch_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
              expert\_dim, expert\_capacity)`, the result of comparing `combine_tensor` against 0.
            - **combine_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device,
              expert\_dim, expert\_capacity)`.
            - **loss** (Tensor) - The load balance loss.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 training=True,
                 parallel_config=None):
        super(SwitchRouter, self).__init__()
        # The signature advertises None as the default, so resolve it to the module's default parallel
        # configuration instead of failing on `None.data_parallel`.
        if parallel_config is None:
            parallel_config = default_dpmp_config
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.expert_parallel = dp
        self.noisy_policy = None
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        # All operators below are sharded along their first dimension by the data-parallel degree.
        self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),))
        self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),))

        self.onehot = P.OneHot().shard(((dp, 1, 1), (), ()))
        self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ()))
        self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ()))
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
        self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
        self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
        # A separate Mul instance is created per call site because the shard strategies differ by operand rank.
        self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul2 = P.Mul().shard(((1,), ()))
        self.mul3 = P.Mul().shard(((1,), ()))
        self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ()))

        # Exclusive cumulative sum: each position holds the sum of the entries before it.
        self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),))
        self.less = P.Less().shard(((dp, 1, 1), ()))
        self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
        self.expand = P.ExpandDims().shard(((dp, 1),))
        self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))

    def _auxiliary_loss(self, expert_mask, router_prob):
        """
        Computing the load balance loss.
        """
        # density_1's shape: (expert_parallel, self.expert_dim)
        density_1 = self.reduce_mean(expert_mask, 1)
        # density_1_proxy's shape: (expert_parallel, self.expert_dim)
        density_1_proxy = self.reduce_mean2(router_prob, 1)
        loss = self.mul(density_1, density_1_proxy)
        loss = self.reduce_mean3(loss)
        # The mean is scaled by expert_dim twice, i.e. by expert_dim squared.
        loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
        return loss

    def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate):
        """
        Keeping only the tokens that fit within expert_capacity.
        """
        # Exclusive cumulative sum of the mask along axis 1 (the tokens_per_device axis per the shape comments).
        cumsum = self.cumsum(expert_mask, 1)
        # position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        position_in_expert = self.mul4(cumsum, expert_mask)
        less_result = self.less(position_in_expert, expert_capacity)
        # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        expert_mask = self.mul5(less_result, expert_mask)
        # expert_mask_flat's shape: (expert_parallel, tokens_per_device)
        expert_mask_flat = self.reduce_sum(expert_mask, -1)

        # Mask out the experts that have overflowed the expert_capacity.
        # expert_gate's shape: (expert_parallel, tokens_per_device)
        expert_gate = self.mul6(expert_gate, expert_mask_flat)
        return expert_gate, expert_mask_flat, position_in_expert

    def construct(self, router_logits):
        # Flatten the logits to 2D, then regroup them as (expert_parallel, tokens_per_device, expert_dim).
        router_logits_shape = self.shape(router_logits)
        router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
        logits_shape = self.shape(router_logits)
        tokens_per_device = logits_shape[0] // self.expert_parallel
        # The first argument is fixed to 1 because only Top1 routing is implemented here.
        expert_capacity = calculate_expert_capacity(1, tokens_per_device, self.capacity_factor, self.expert_dim)
        router_logits = self.reshape(router_logits, (self.expert_parallel, tokens_per_device, self.expert_dim))
        # Currently, lack of gumbel sampler for router_logits.

        # Probabilities for each token of what expert it should be sent to
        router_prob = self.softmax(router_logits)
        # shape is : (expert_parallel, tokens_per_device)
        expert_index, expert_gate = self.argmax(router_prob)
        # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)

        # Computing the load balance loss:
        loss = self._auxiliary_loss(expert_mask, router_prob)

        expert_gate, expert_mask_flat, position_in_expert = \
            self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate)

        # combine_tensor's shape: (expert_parallel, tokens_per_device)
        combine_tensor = self.mul7(expert_gate, expert_mask_flat)
        # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        combine_tensor = self.mul8(self.expand(combine_tensor, -1),
                                   self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
        # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity)
        combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
                                   self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
                                                self.on_value, self.off_value))
        # dispatch_tensor is of boolean type. Here, using NotEqual instead of Cast, for that 'Cast to bool' has
        # bad performance
        dispatch_tensor = self.not_equal(combine_tensor, 0.0)
        return dispatch_tensor, combine_tensor, loss
